# -*- coding: utf-8 -*-
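"""
Evaluate the trained PhyCRNet runs found under ./model/, compute SHAP
attributions and run IES inversions through ies_utils.ies_main, writing
figures and pickled results under ./ies_result/. Only the first run is
processed, because the main loop ends with `break`.

Created on Fri Apr 15 14:22:23 2022
"""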

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import numpy as np
import pickle
import os
from tqdm import tqdm
from models import PhyCRNet,MaskedMSELoss,MaskedL1Loss,weightMSELoss
from data_utils import PrepareConcDataset
from utils import plot_conc,read_logs,count_weight
from ies_utils import ies_main,cnn2gnn5,get_input_padding,plot_ensemble_output,plot_convergence,pad_ensemble2input,get_nonzero_matrix,get_sense_well,plot_pj_ratio
from shap_utils import get_shap_value,plot_shap_sig,get_topk_array,process_shap


seed=1000


def set_seed(seed):
    """Seed every random number generator the script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and,
    when CUDA is available, every CUDA device). On CUDA it also forces
    deterministic cuDNN kernels and disables cuDNN autotuning.

    Args:
        seed: Integer seed applied to every generator.

    Note:
        Assigning ``PYTHONHASHSEED`` here does not change hash randomization of
        the already-running interpreter; it only affects child processes
        started afterwards.
    """
    import random  # stdlib RNG; imported locally so the file's import block is untouched

    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Trade speed for bit-reproducible convolutions.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


set_seed(seed)

#%%
args={}
args['cuda'] = torch.cuda.is_available()
# Use the GPU when one is available, otherwise fall back to the CPU so that
# `device` is always bound (it is needed later for model placement and plot_conc).
device = torch.device('cuda' if args['cuda'] else 'cpu')


#%%
# Root directory handed to read_logs to discover the saved runs.
loot_log_path = './model/'

# read_logs returns a dict of per-run lists that are all indexed by the same run
# index. Keys consumed below: 'orders' (its length sets the number of runs),
# 'args_dirs', 'model_dirs' and 'log_dirs'. An earlier draft of this script also
# listed an 'events_dirs' entry.
dirs = read_logs(loot_log_path)

#%%
for index in range(len(dirs['orders'])):
    # NOTE(review): this loop body ends with `break`, so only the first run
    # (index 0) is actually processed.
        
    # Restore the training hyper-parameters pickled for this run.
    with open(dirs['args_dirs'][index],'rb') as f:
        args=pickle.load(f)
        
    # Rebuild datasets, graph metadata and scalers with this run's settings.
    train_input,train_label1,train_label1_coe,train_label2,train_label2_mask,\
    test_input,test_label1,test_label1_coe,test_label2,test_label2_mask,\
    mmin_list,dev_list,conc_name,edge_list,count,node_coord,\
    time_scaler,scaler_Q,scaler_H,tot_uranium=PrepareConcDataset(args['time_window'],args['res'],
                                                                 args['pred_day'],args['buffer'],
                                                                 args['gap'],args['test_day'],
                                                                 args['val_ratio'],flag='plot')
      
    # train_input is 5-D; its dims are unpacked as (batch, time, channels, h, w).
    [batch_size,time_len,input_channels,h,w] = train_input.size()
    model_path=dirs['model_dirs'][index]
    
    # Rebuild the network from the saved hyper-parameters and graph data.
    model = PhyCRNet(args['input_dim'],
                     args['hidden_dim'],
                     args['output_dim'],
                     args['input_kernel_size'],
                     args['input_stride'],
                     args['padding_mode'],
                     args['num_layers'], 
                     args['dropout'],
                     args['warmup_updates'],
                     args['tot_updates'],
                     args['peak_lr'],
                     args['end_lr'],
                     args['power'],
                     args['weight_decay'],
                     edge_list,
                     scaler_Q,
                     mmin_list,
                     dev_list)
    
    # NOTE(review): torch.load is called without map_location; a checkpoint
    # saved from GPU will presumably fail to load on a CPU-only machine — confirm.
    model.load_state_dict(torch.load(model_path))
    model=model.to(device=device)
    count_weight(model)

    # all_mae is not used further in this script.
    all_mae=plot_conc(args,model,dirs['log_dirs'][index],device,'sig')

    
    #%%
    # The SHAP and IES stages below run the model on the CPU.
    model=model.cpu()
    loot_ies_path='./ies_result/'
    if not os.path.exists(loot_ies_path): 
        os.makedirs(loot_ies_path)
    
    # Settings consumed by ies_main, plus the SHAP sample count / top-k used below.
    ies_args={}
    ies_args['num_ensemble']=1000
    ies_args['init_lambd']=1
    ies_args['beta']=0.08
    ies_args['max_out_iter']=50
    ies_args['max_in_iter']=10
    ies_args['lambd_reduct']=0.8
    ies_args['lambd_incre']=2
    ies_args['do_tsvd']=1
    ies_args['tsvd_cut']=0.99
    ies_args['min_rn']=0.01
    ies_args['noise']=0.
    ies_args['max_lambd']=1e2
    ies_args['inverse_day']=7
    ies_args['num_samples']=200
    ies_args['topk']=5
    ies_args['std_diffuse']=1
    topk=ies_args['topk']
    # Bounds given in original units, mapped into each scaler's standardized
    # space via (x - mean_) / scale_.
    ies_args['inject_lower_bound']=(-425-scaler_Q.mean_)/scaler_Q.scale_#90% quantile
    ies_args['inject_upper_bound']=(0-scaler_Q.mean_)/scaler_Q.scale_
    ies_args['pump_lower_bound']=(0-scaler_Q.mean_)/scaler_Q.scale_
    ies_args['pump_upper_bound']=(261-scaler_Q.mean_)/scaler_Q.scale_#90% quantile
    ies_args['H_lower_bound']=(0-scaler_H.mean_)/scaler_H.scale_
    ies_args['H_upper_bound']=(6.8-scaler_H.mean_)/scaler_H.scale_#90% quantile
    
    # Target levels tried by the inversion loop below (50..57).
    # criterion2 is not used anywhere in this script.
    criterion1=np.arange(50,58)
    criterion2=2
    
    # Baseline: pad the last training sample for inverse_day days and predict.
    last_input=train_input[-1:,:,:,:,:].numpy()
    temp_input=get_input_padding(last_input,ies_args['inverse_day'])
    temp_input=torch.FloatTensor(temp_input)
    temp_target=model(temp_input)
    # First 7 of the last 60 tot_uranium values; used as the observed reference
    # in plot_ensemble_output. NOTE(review): 7 is hard-coded and presumably
    # meant to equal ies_args['inverse_day'].
    real_uranium=tot_uranium[-60:][:7]
    
    log_dir=dirs['log_dirs'][index]
    
    # Save the IES settings next to this run's logs.
    with open(log_dir+'/ies_args.pkl','wb') as f:
        pickle.dump(ies_args,f)


    # SHAP attribution of the padded input; the results are presumably written
    # under log_dir and read back by process_shap — confirm in shap_utils.
    get_shap_value(model,train_input,temp_input,ies_args['num_samples'],log_dir)
    
    peak_eg_values,peak_eg_values_var=process_shap(log_dir,time_len,input_channels,h,w,count,edge_list)

    
    # Reload the saved weights into the (CPU) model.
    # NOTE(review): presumably guards against get_shap_value altering the model — confirm.
    model.load_state_dict(torch.load(model_path))
    loot_figure_path='./ies_result/'
    # Drop the first 8 characters of the log dir. len('./model/') == 8, so this
    # presumably strips the loot_log_path prefix to get a relative run name —
    # confirm read_logs returns paths of that form.
    log_dir=dirs['log_dirs'][index][8:]
    file_type='svg'
    folder_name='test_sig'+'_'+file_type
    
    # One SHAP significance figure per inversion day.
    for peak_date in range(ies_args['inverse_day']):
        
        plot_shap_sig(edge_list,node_coord,loot_figure_path,log_dir,folder_name,file_type,peak_date,topk,
                      input_channels,time_len,peak_eg_values,peak_eg_values_var)
        
# '''
    #%%
    
    # Pick the top-k wells from the SHAP values (presumably the most
    # influential ones — see get_sense_well).
    tot_sense_well=get_sense_well(peak_eg_values,edge_list,ies_args['topk'],)
        
    # Convert to NumPy. base_target keeps the unmodified baseline prediction,
    # because temp_target is overwritten for each criterion below.
    temp_input=temp_input.detach().numpy()
    temp_target=temp_target.detach().numpy()
    base_target=temp_target.copy()
    last_input=cnn2gnn5(last_input,edge_list)
    temp_input=cnn2gnn5(temp_input,edge_list)
    

    # Un-standardize the test inputs (x * scale_ + mean_) and compute the
    # observed pump/inject ratio for test sample 6.
    # NOTE(review): assumes channel 1 holds well rates, with nodes [:20] pumping
    # and [20:] injecting — confirm against cnn2gnn5's node ordering.
    test_input=cnn2gnn5(test_input,edge_list)
    test_input=test_input*scaler_Q.scale_+scaler_Q.mean_
    real_pump=test_input[6:7,:,1,:20].sum(axis=-1)
    real_inject=test_input[6:7,:,1,20:].sum(axis=-1)
    real_pj_ratio=abs(real_pump/real_inject)
    
    # NOTE(review): the hard-coded 5 presumably equals ies_args['topk'].
    inverse_flux_well=tot_sense_well[:,:5,0]
    
    
    #%%ies
    tot_pj_ratio=[]
    for i in tqdm(range(len(criterion1))):
        # Constant target equal to the current criterion level.
        temp_target=np.ones_like(temp_target)*criterion1[i]
        
        
        ensemble,ensemble_output,ensemble_input,objs,lambds=ies_main(ies_args,model,edge_list,count,
                                                                      mmin_list,dev_list,scaler_Q,scaler_H,inverse_flux_well,
                                                                      tot_sense_well,last_input,temp_input,temp_target)
        
        # Cache the IES results for this criterion.
        with open(loot_figure_path+log_dir+f'/ensemble_list{criterion1[i]}.pkl','wb') as f:
            pickle.dump([ensemble,ensemble_output,ensemble_input,objs,lambds],f)
    
    
        #%%
        # Read the cache straight back; presumably this lets the following
        # cells run on their own in an interactive #%% session.
        with open(loot_figure_path+log_dir+f'/ensemble_list{criterion1[i]}.pkl','rb') as f:
            [ensemble,ensemble_output,ensemble_input,objs,lambds]=pickle.load(f)
        
        #%%
        plot_ensemble_output(ensemble_output,temp_target,base_target,real_uranium,loot_figure_path,log_dir,criterion1[i],i)
        plot_convergence(objs,lambds,loot_figure_path,log_dir,criterion1[i])
        
        # NOTE(review): named "mean" but this is the last row of `ensemble` —
        # confirm ies_main stores the ensemble mean there.
        # Reshaped to (inverse_day, tot_sense_well.shape[1]).
        ensemble_mean=ensemble[-1,:].reshape(ies_args['inverse_day'],tot_sense_well.shape[1])
        
        # Split into flux (first 5 columns) and H (remaining columns), then
        # un-standardize. Neither result is used further below.
        inversed_flux=ensemble_mean[:,:5]
        inversed_H=ensemble_mean[:,5:]
        
        inversed_flux=scaler_Q.inverse_transform(inversed_flux.reshape(-1,1)).reshape(inversed_flux.shape)
        inversed_H=scaler_H.inverse_transform(inversed_H.reshape(-1,1)).reshape(inversed_H.shape)
        
        # Pump/inject ratio implied by the last ensemble_input entry, using the
        # same channel-1 / node-20 split as real_pj_ratio above.
        ensemble_mean=cnn2gnn5(ensemble_input[-1],edge_list)[:,:,1,:]*scaler_Q.scale_+scaler_Q.mean_
        en_pump=ensemble_mean[:,:,:20].sum(axis=-1)
        en_inject=ensemble_mean[:,:,20:].sum(axis=-1)
        
        en_pj_ratio=abs(en_pump/en_inject)
        tot_pj_ratio.append(en_pj_ratio)
    plot_pj_ratio(ies_args,tot_pj_ratio,real_pj_ratio,loot_figure_path,log_dir,criterion1)
    # Stop after the first run (see note at the top of the loop).
    break
# '''   




